import numpy as np
import time
from copy import copy
from env import MatrixEnv
from utils import ValueTable, Policies

def policy_improve(env, values, policies):
    """Greedily update the policy at every state from the current value table.

    For each state, evaluate the one-step expected return of every action and
    make the policy deterministic on the best one.

    Returns:
        bool: True when no state's greedy action changed (policy is stable).
    """
    print('\n===== Policy Improve =====')
    stable = True

    for state in env.states:
        previous_action = policies.sample(state)

        # Expected one-step return of each candidate action from this state.
        action_returns = []
        for action in env.actions:
            env.set_state(state)
            _, _, rewards, next_states = env.step(action)
            next_values = values.get(list(next_states))
            # One-step TD targets: r + gamma * V(s')
            targets = [r + env.gamma * v for r, v in zip(rewards, next_values)]
            # Successor states are treated as equally likely.
            weight = 1 / len(next_states)
            action_returns.append(sum(weight * t for t in targets))

        best_action = env.actions[np.argmax(action_returns)]

        # Replace the distribution with a one-hot on the best action.
        distribution = [0.] * env.action_space
        distribution[best_action] = 1.
        policies.update(state, distribution)

        if previous_action != best_action:
            stable = False

    return stable

def value_iter(env, values, upper_bound):
    """Run synchronous value iteration until convergence.

    Sweeps over all states, backing up every action and keeping the maximum
    backup (unlike policy iteration, which evaluates only the sampled policy
    action). Stops once the largest single-state update in a sweep drops
    below ``upper_bound``.
    """
    print('===== Value Iteration =====')
    delta = upper_bound + 1.
    states = copy(env.states)
    iteration = 0

    while delta >= upper_bound:
        delta = 0

        for state in states:
            old_value = values.get(state)

            # Back up each action; the state's new value is the best backup.
            backups = []
            for action in env.actions:
                env.set_state(state)
                _, _, rewards, next_states = env.step(action)
                targets = [r + env.gamma * v
                           for r, v in zip(rewards, values.get(next_states))]
                backups.append(np.mean(targets))

            values.update(state, max(backups))
            delta = max(delta, abs(old_value - values.get(state)))

        iteration += 1
        print('\r> iteration: {} delta: {}'.format(iteration, delta), end="", flush=True)

    return

env = MatrixEnv(width=8, height=8)  # try different world sizes
policies = Policies(env)
values = ValueTable(env)
upper_bound = 1e-4  # convergence threshold for value iteration

start = time.time()
value_iter(env, values, upper_bound)
# Derive the greedy policy from the converged value function.
_ = policy_improve(env, values, policies)
end = time.time()

print('\n[time consumption] {}s'.format(end - start))

# Roll the learned policy out once and report total reward and step count.
env.reset()
done = False
rewards = 0
step = 0
while not done:
    act_index = policies.sample(env.state)
    # next_state is unused during evaluation; the env tracks its own state.
    _, done, reward, _ = env.step(env.actions[act_index])
    rewards += sum(reward)
    step += 1

print('Evaluation: [reward] {} [step] {}'.format(rewards, step))